from datetime import datetime
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import geopandas as gpd
import json
from ipywidgets import interactive, HBox, VBox, interact, widgets
from IPython.display import IFrame
# Register a custom "rockwell" template: Rockwell font, size 18, dark-blue
# text, and figure titles anchored at their vertical middle.
pio.templates['rockwell'] = go.layout.Template(
    layout=go.Layout(
        font={'family': 'Rockwell', 'size': 18, 'color': '#2A5674'},
        title=dict(yanchor='middle'),
    )
)

# Use the built-in "plotly_white" template merged with the custom one as the
# default for figures created from here on.
pio.templates.default = "plotly_white+rockwell"
Подготовим данные для визуализации - реальные каунты и прогнозы.
# Hourly timestamps from 2018-06-30 23:00 through 2018-07-31 23:00.
# The first element is skipped (july[1:]) when selecting the real counts.
july = pd.date_range(
    start=datetime(2018, 6, 30, 23),
    end=datetime(2018, 7, 31, 23),
    freq='H',
)

# Real (observed) counts, restricted to the timestamps in july[1:]
july_real_counts = pd.read_csv(
    'NYC_TAXI_aggregated_data/pu_agg_data_pop.csv',
    index_col=[0],
    parse_dates=[0],
)
july_real_counts = july_real_counts.loc[july[1:]]
july_real_counts.head()

# Forecasts for July; zone identifiers are cast to str so they can be
# compared with the (string) column labels of the real counts
temp_july_forecasts = pd.read_csv(
    'XGBoost/july_forecasts.csv',
    index_col=[0],
    parse_dates=[0],
)
temp_july_forecasts['zone'] = temp_july_forecasts['zone'].astype(str)
temp_july_forecasts.sample(5)
Приведем прогнозы к более удобному формату для визуализации: разобьем по зонам и усредним прогнозы по моделям.
# Build one forecast series per zone: for every zone, take the predictions
# of each of the six `shift` values (1..6), move each series forward by its
# own shift, and average the six aligned series row by row.
zones = july_real_counts.columns
july_forecasts = pd.DataFrame(index=july)

for zone in zones:
    # All forecast rows of the current zone (every shift)
    zone_rows = temp_july_forecasts[temp_july_forecasts['zone'] == zone]
    per_step = pd.DataFrame(index=july)
    for horizon in range(1, 7):
        column = 'step_{}'.format(horizon)
        # Assignment aligns the predictions on the `july` index ...
        per_step[column] = zone_rows[zone_rows['shift'] == horizon][['prediction']]
        # ... and the series is then moved `horizon` rows forward
        per_step[column] = per_step[column].shift(horizon)
    # Row-wise mean over the available (non-NaN) steps
    july_forecasts[zone] = per_step.mean(axis=1)

july_forecasts.columns.name = 'zone_id'
july_forecasts.index.name = 'pickup_datetime'
# Keep only timestamps that have a forecast for every zone
july_forecasts = july_forecasts.dropna()
july_forecasts.head()
Получили привычные временные ряды по каждой зоне.
# Drop the raw forecast table; only july_forecasts is used from here on.
del temp_july_forecasts
# Callback for the zone selection widget
def update_zone(zone_id):
    """Show the real and predicted series of the selected zone on the chart.

    Args:
        zone_id: Column label present in both ``july_real_counts`` and
            ``july_forecasts``.
    """
    # Apply both trace changes in one batch so the widget refreshes once
    # instead of once per assignment.
    with timeseries_plot.batch_update():
        timeseries_plot.data[0].y = july_real_counts[zone_id]
        timeseries_plot.data[1].y = july_forecasts[zone_id]
# The chart itself.
# Each trace takes its x values from the index of the frame that supplies its
# y values. `july` starts one hour earlier (2018-06-30 23:00) than
# `july_real_counts`, and `july_forecasts` has had rows removed by dropna(),
# so pairing either series with `july` positionally would draw the points at
# the wrong timestamps.
trace_real = go.Scatter(x=july_real_counts.index,
                        y=july_real_counts['4'],
                        name='Real counts',
                        line=dict(width=2, color='#6785be')
                        )
trace_predicted = go.Scatter(x=july_forecasts.index,
                             y=july_forecasts['4'],
                             name='Predicted counts',
                             line=dict(width=2, dash='dot', color='#ba6657')
                             )
data = [trace_real, trace_predicted]
layout = dict(title='Real and predicted timerows of taxi zones (JULY 2018)',
              # Range slider under the x axis for zooming into a time window
              xaxis=dict(rangeslider=dict(visible=True),
                         type='date'),
              font=dict(family='Rockwell'),
              height=600,
              margin=dict(l=20, r=10, b=0, t=50, pad=0),
              )
# FigureWidget (rather than Figure) so that update_zone can modify the traces
# in place from the widget callback.
timeseries_plot = go.FigureWidget(data=data, layout=layout)
# Zone selection widget: a dropdown listing every zone, initially set to '4'
zone_dropdown = widgets.Dropdown(
    options=zones,
    value='4',
    description='chooze zone_id',
    disabled=False,
)
# Bind the dropdown to update_zone so a new selection redraws both traces
choose_zone_drop_down = interactive(update_zone, zone_id=zone_dropdown)

# Chart on top, dropdown underneath
VBox([timeseries_plot, choose_zone_drop_down])

# Pre-recorded GIF of the interactive chart
IFrame('exp/real_and_predicted_timerows.gif', 1000, 700, unconfined=True)
# Taxi zone geometries: keep four columns and give them project-wide names
zones_gdf = (
    gpd.read_file('NYC_TAXI_data/other_data/taxi_zones.geojson')
    .loc[:, ['zone', 'OBJECTID', 'borough', 'geometry']]
    .rename(columns={'zone': 'zone_name', 'OBJECTID': 'zone_id'})
)
# zone_id is cast to str so it can be joined with the forecast column labels
zones_gdf['zone_id'] = zones_gdf['zone_id'].astype(str)

# Map centre: midpoint of the latitude / longitude pairs below
NY_center_lat = (40.49612 + 40.91553) / 2
NY_center_lon = (-74.25559 - 73.70001) / 2

# One row per zone with one column per forecast timestamp, attached to the
# geometries; zones without forecasts are kept (left join)
forecasts_by_zone = july_forecasts.T.reset_index()
zones_gdf_with_forecast = zones_gdf.merge(forecasts_by_zone, how='left', on='zone_id')
# Timestamp column labels become strings so they can be looked up by str(dt)
zones_gdf_with_forecast.columns = zones_gdf_with_forecast.columns.astype(str)
zones_gdf_with_forecast.sample(2)
# Callback that switches the timestamp shown on the map
def update_datetime(dt):
    """Recolor the map with the forecasts of the selected timestamp.

    Args:
        dt: Timestamp whose string form is a column label of
            ``zones_gdf_with_forecast``.
    """
    forecast_column = str(dt)
    map_plot.data[0].z = zones_gdf_with_forecast[forecast_column]
# GeoJSON of the zone polygons. Only zone_id and the geometry are exported:
# the forecast values are passed to the trace through `z`, so embedding every
# forecast column in each feature's properties would only inflate the payload.
zones_geojson = json.loads(zones_gdf_with_forecast[['zone_id', 'geometry']].to_json())
# The map
data = go.Choroplethmapbox(geojson=zones_geojson,
                           locations=zones_gdf_with_forecast['zone_id'],
                           # Match locations against the zone_id property.
                           # The default key is the feature "id", which
                           # GeoDataFrame.to_json() fills with the row index,
                           # not with zone_id.
                           featureidkey='properties.zone_id',
                           z=zones_gdf_with_forecast[str(july[1])],
                           hovertext=zones_gdf_with_forecast['zone_name'],
                           hovertemplate='<b>Zone name</b>: <b>%{hovertext}</b>'+
                                         '<br><b>Zone ID </b>: %{location}'+
                                         "<extra></extra>",
                           showlegend=False,
                           autocolorscale=False,
                           colorscale='Viridis',
                           showscale=True,
                           marker_opacity=0.8, marker_line_width=0.1)
layout = go.Layout(mapbox_style='carto-positron',
                   mapbox_zoom=9,
                   mapbox_center={'lat': NY_center_lat, 'lon': NY_center_lon},
                   hoverlabel=dict(bgcolor="white", font_size=12, font_family='Rockwell'),
                   margin={"r": 0, "t": 100, "l": 0, "b": 0},
                   title='Most popular New York taxi zones (july 2018)<br>'+
                         '(colored by predicted trip count)',
                   height=600, width=600)
# FigureWidget so that update_datetime can replace `z` in place
map_plot = go.FigureWidget(data=data, layout=layout)
# Time slider: selectable timestamps are elements 1 .. 24*7 of `july`
# (168 hourly steps)
slider_hours = july[1:24*7+1]
hour_selector = widgets.SelectionSlider(
    options=slider_hours,
    description=' ',
    layout=widgets.Layout(width='600px'),
    style={'description_width': 'initial'},
)
# Bind the slider to update_datetime so moving it recolors the map
datetime_slider = interactive(update_datetime, dt=hour_selector)

# Map on top, slider underneath
VBox([map_plot, datetime_slider])

# Pre-recorded GIF of the interactive map
IFrame('exp/map_with_predictions.gif', 1000, 650, unconfined=True)